# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import itertools as it
from abc import ABCMeta, abstractmethod
from hysop.tools.htypes import check_instance, first_not_None
from hysop.tools.numpywrappers import npw
from hysop.tools.misc import next_pow2, upper_pow2
from hysop.backend.device.kernel_autotuner_config import KernelAutotunerConfig
from hysop.backend.device.codegen.structs.mesh_info import MeshInfoStruct
from hysop.fields.cartesian_discrete_field import CartesianDiscreteScalarFieldView
class AutotunableKernel(metaclass=ABCMeta):
def __init__(
self, autotuner_config, build_opts, dump_src=None, symbolic_mode=None, **kwds
):
super().__init__(**kwds)
self._check_build_configuration(autotuner_config, build_opts)
self.autotuner_config = autotuner_config
self.build_opts = build_opts
self.dump_src = first_not_None(dump_src, autotuner_config.debug)
self.symbolic_mode = first_not_None(symbolic_mode, autotuner_config.debug)
def custom_hash(self, *args, **kwds):
HASH_DEBUG = self.autotuner_config.dump_hash_logs
assert args or kwds, "no arguments to be hashed."
def _hash_arg(a):
s = ""
if a is None:
s += "\nNone"
h = hash("None")
elif a is Ellipsis:
s += "\nEllipsis"
h = hash("Ellipsis")
elif isinstance(a, str):
if HASH_DEBUG:
s += f"\n>HASHING STR: {a}"
h = hash(a)
if HASH_DEBUG:
s += f"\n<HASHED STR: hash={h}"
elif isinstance(a, list):
if HASH_DEBUG:
s += "\n>HASHING LIST:"
h = hash(tuple(_hash_arg(x)[0] for x in a))
if HASH_DEBUG:
s += f"\n<HASHED LIST: hash={h}"
elif isinstance(a, tuple):
if HASH_DEBUG:
s += "\n>HASHING TUPLE:"
h = hash(tuple(_hash_arg(x)[0] for x in a))
if HASH_DEBUG:
s += f"\n<HASHED TUPLE: hash={h}"
elif isinstance(a, (set, frozenset)):
if HASH_DEBUG:
s += "\n>HASHING SET:"
h = hash(tuple(_hash_arg(x)[0] for x in sorted(a)))
if HASH_DEBUG:
s += f"\n<HASHED SET: hash={h}"
elif isinstance(a, dict):
if HASH_DEBUG:
s += "\n>HASHING DICT:"
h = hash(tuple((_hash_arg(k)[0], _hash_arg(a[k])[0]) for k in sorted(a.keys())))
if HASH_DEBUG:
s += f"\n<HASHED DICT: hash={h}"
elif isinstance(a, npw.ndarray):
if HASH_DEBUG:
s += "\n>HASHING NDARRAY:"
assert a.ndim <= 1
assert a.size < 17, "Only parameters up to size 16 are allowed."
hh, ss = self.custom_hash(a.tolist())
h = hh
s += ss
if HASH_DEBUG:
s += f"\n>HASHED NDARRAY: hash={h}"
else:
h = hash(a)
if HASH_DEBUG:
s += f"\n>HASHED UNKNOWN TYPE {type(a)}: hash={h}"
assert h != id(a), type(a)
return h, s
def _hash_karg(k, v):
s = ""
if k == "mesh_info_vars":
# for mesh infos we only hash the code-generated constants that
# may alter code branching.
if HASH_DEBUG:
s += "\n<HASHING MESHINFO"
from hysop.backend.device.codegen.base.variables import CodegenStruct
check_instance(v, dict, keys=str, values=CodegenStruct)
mesh_infos = tuple(str(v[k]) for k in sorted(v.keys()))
h = hash(mesh_infos)
if HASH_DEBUG:
s += "\n MESH INFOS:"
for mi in mesh_infos:
s += "\n " + mi
s += f"\n>HASHED MESHINFO: hash={h}"
return h, s
elif k == "expr_info":
# for expr infos we only hash the continuous and discrete expressions
# and some additional variables
if HASH_DEBUG:
s += "\n>HASHING EXPR_INFO:"
exprs = tuple(str(e) for e in v.exprs)
exprs += tuple(str(e) for e in v.dexprs)
extras = (v.name, v.direction, v.has_direction, v.dt_coeff, v.kind)
for k in sorted(
v.min_ghosts_per_components.keys(), key=lambda x: x.name
):
extras += (k.name, _hash_arg(v.min_ghosts_per_components[k]))
for mem_obj_key in (
"input_arrays",
"output_arrays",
"input_buffers",
"output_buffers",
"input_params",
"output_params",
):
mem_objects = getattr(v, mem_obj_key)
for k in sorted(mem_objects, key=lambda x: x[0]):
assert hasattr(mem_objects[k], "short_description"), type(
mem_objects[k]
).__mro__
extras += (k, hash(mem_objects[k].short_description()))
hh, ss = self.custom_hash(exprs + extras)
h = hh
s += ss
if HASH_DEBUG:
s += "\n EXPRESSIONS:"
for e in exprs:
s += f"\n {e} {type(e)}"
s += f"\n with hash {self.custom_hash(e)[1]}"
s += "\n EXTRAS:"
for e in extras:
s += f"\n {e} {type(e)}"
s += f"\n with hash {self.custom_hash(e)[1]}"
s += f"\n<HASHED EXPR_INFO: hash={h}"
return h, s
else:
msg = f"Unknown custom hash key '{k}'."
raise KeyError(msg)
def hash_all(*args, **kwds):
h, s = None, None
if args:
h, s = _hash_arg(args[0])
if HASH_DEBUG:
s += f"\nHASHED ARGUMENT 0: {h}"
for i, arg in enumerate(args[1:], start=1):
hh, ss = _hash_arg(arg)
h ^= hh
if HASH_DEBUG:
s += ss
s += f"\nHASHED ARGUMENT {i}: {h}"
if kwds:
items = tuple(sorted(kwds.items(), key=lambda x: x[0]))
if h is None:
h, s = _hash_karg(*items[0])
else:
hh, ss = _hash_karg(*items[0])
h ^= hh
if HASH_DEBUG:
s += ss
s += f"\nHASHED KWD 0: {h}"
for i, item in enumerate(items[1:], start=1):
hh, ss = _hash_karg(*item)
h ^= hh
if HASH_DEBUG:
s += ss
s += f"\nHASHED KWD {i}: {h}"
return h, s
h, s = hash_all(*args, **kwds)
return h, s
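# Note on the combination rule above (illustrative, the argument values are
# made up): per-argument hashes are combined with XOR, so the result does not
# depend on positional argument order, and two identical positional arguments
# cancel each other out:
#
#     h1, _ = self.custom_hash("nx", (256, 256))
#     h2, _ = self.custom_hash((256, 256), "nx")
#     assert h1 == h2  # XOR is commutative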
@abstractmethod
def autotune(
self, name, kernel_args, force_verbose=False, force_debug=False, **extra_kwds
):
"""Autotune this kernel with given name and extra_kwds."""
pass
@abstractmethod
def max_device_work_dim(self):
"""Maximum dimensions that specify the global and local work-item IDs."""
pass
@abstractmethod
def max_device_work_group_size(self):
"""Return the maximum number of work items allowed by the device."""
pass
@abstractmethod
def max_device_work_item_sizes(self):
"""
Maximum number of work-items that can be specified in each dimension
of the work-group.
"""
pass
@abstractmethod
def compute_args_mapping(self, extra_kwds, extra_parameters):
"""
Return the argument mapping: a dictionary with argument
names as keys and tuples as values.
Each tuple contains (arg_position, arg_type(s)), where
arg_position is an int and arg_type(s) is a type or a
tuple of types that the argument will be checked against.
"""
pass
def compute_parameters(self, extra_kwds):
"""Register extra parameters to optimize."""
return AutotunerParameterConfiguration()
def compute_work_bounds(
self,
max_kernel_work_group_size,
preferred_work_group_size_multiple,
extra_parameters,
extra_kwds,
work_size=None,
work_dim=None,
min_work_load=None,
max_work_load=None,
):
"""
Configure work_bounds (work_dim, work_size, max_work_load).
Return a WorkBoundsConfiguration object.
"""
check_instance(max_kernel_work_group_size, int)
check_instance(preferred_work_group_size_multiple, int)
check_instance(extra_parameters, dict, keys=str)
check_instance(extra_kwds, dict, keys=str)
assert max_kernel_work_group_size > 0, max_kernel_work_group_size
assert (
preferred_work_group_size_multiple > 0
), preferred_work_group_size_multiple
msg = "FATAL ERROR: Could not extract {} from keyword arguments, "
msg += "extra_parameters and extra_kwds."
msg += f"\nFix {type(self)}::compute_work_bounds()."
work_dim = first_not_None(
work_dim,
extra_parameters.get("work_dim", None),
extra_kwds.get("work_dim", None),
)
max_work_dim = self.max_device_work_dim()
if work_dim is None:
msg = msg.format("work_dim")
raise RuntimeError(msg)
elif work_dim > max_work_dim:
msg = "Got work_dim {} but maximum supported by device is {}."
msg = msg.format(work_dim, max_work_dim)
raise ValueError(msg)
work_size = first_not_None(
work_size,
extra_parameters.get("work_size", None),
extra_kwds.get("work_size", None),
)
if work_size is None:
msg = msg.format("work_size")
raise RuntimeError(msg)
min_work_load = first_not_None(
min_work_load,
extra_parameters.get("min_work_load", None),
extra_kwds.get("min_work_load", None),
(1,) * work_dim,
)
max_work_load = first_not_None(
max_work_load,
extra_parameters.get("max_work_load", None),
extra_kwds.get("max_work_load", None),
min_work_load,
)
assert min_work_load is not None
assert max_work_load is not None
max_device_work_dim = self.max_device_work_dim()
max_device_work_group_size = self.max_device_work_group_size()
max_device_work_item_sizes = self.max_device_work_item_sizes()
max_work_group_size = min(
max_device_work_group_size, max_kernel_work_group_size
)
work_bounds = AutotunerWorkBoundsConfiguration(
work_dim=work_dim,
work_size=work_size,
min_work_load=min_work_load,
max_work_load=max_work_load,
max_device_work_dim=max_device_work_dim,
max_device_work_group_size=max_work_group_size,
max_device_work_item_sizes=max_device_work_item_sizes,
preferred_work_group_size_multiple=preferred_work_group_size_multiple,
)
return work_bounds
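# Illustrative call (a sketch only, all values are assumptions and the device
# is assumed to support 3 work dimensions): work_dim and work_size are looked
# up in explicit arguments first, then extra_parameters, then extra_kwds, and
# the work loads default to one element per work item when not provided:
#
#     bounds = self.compute_work_bounds(
#         max_kernel_work_group_size=256,
#         preferred_work_group_size_multiple=32,
#         extra_parameters={},
#         extra_kwds={"work_dim": 3, "work_size": (128, 128, 128)},
#     )
#     # bounds.min_work_load == bounds.max_work_load == [1, 1, 1]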
def compute_work_candidates(
self, work_bounds, work_load, extra_parameters, extra_kwds
):
"""
Configure work (global_size, local_size candidates) given an
AutotunerWorkBoundsConfiguration instance and a work_load.
Return an AutotunerWorkConfiguration instance.
"""
check_instance(work_bounds, AutotunerWorkBoundsConfiguration)
check_instance(
work_load, npw.ndarray, dtype=npw.int32, size=work_bounds.work_dim
)
check_instance(extra_parameters, dict, keys=str)
check_instance(extra_kwds, dict, keys=str)
global_work_size = (work_bounds.work_size + work_load - 1) // work_load
(min_wg_size, max_wg_size) = self.compute_min_max_wg_size(
work_bounds=work_bounds,
work_load=work_load,
global_work_size=global_work_size,
extra_parameters=extra_parameters,
extra_kwds=extra_kwds,
)
work = AutotunerWorkConfiguration(
work_bounds=work_bounds,
work_load=work_load,
min_wg_size=min_wg_size,
max_wg_size=max_wg_size,
)
return work
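# Note: global_work_size above is a ceil-division of the work size by the
# work load; for example a work_size of (100,) with a work_load of (8,)
# gives a global_work_size of (13,), each work item processing 8 elements
# (the last one only 4).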
def compute_min_max_wg_size(
self, work_bounds, work_load, global_work_size, extra_parameters, extra_kwds
):
"""Default min and max workgroup size."""
min_wg_size = npw.ones(shape=work_bounds.work_dim, dtype=npw.int32)
max_wg_size = global_work_size.copy()
return (min_wg_size, max_wg_size)
@abstractmethod
def compute_global_work_size(
self, work, local_work_size, extra_parameters, extra_kwds
):
"""
Compute aligned global_work_size from unaligned global_work_size
and local_work_size.
"""
pass
@abstractmethod
def generate_kernel_src(
self,
global_work_size,
local_work_size,
extra_parameters,
extra_kwds,
tuning_mode,
dry_run,
):
"""
Generate kernel source code as a string.
"""
pass
@classmethod
def _check_build_configuration(cls, autotuner_config, build_opts):
"""Check autotuner_config and build options."""
check_instance(autotuner_config, KernelAutotunerConfig)
check_instance(build_opts, tuple)
@classmethod
def check_cartesian_field(
cls,
field,
dtype=None,
size=None,
resolution=None,
compute_resolution=None,
nb_components=None,
ghosts=None,
min_ghosts=None,
max_ghosts=None,
domain=None,
topology=None,
):
check_instance(field, CartesianDiscreteScalarFieldView)
if (domain is not None) and (field.domain.domain is not domain):
msg = "Domain mismatch for dfield {}."
msg = msg.format(field.name)
raise RuntimeError(msg)
if (topology is not None) and (field.topology.topology is not topology):
msg = "Topology mismatch for dfield {}."
msg = msg.format(field.name)
raise RuntimeError(msg)
if (size is not None) and (field.npoints != size):
msg = "Size mismatch for dfield {}."
msg = msg.format(field.name)
raise RuntimeError(msg)
if (resolution is not None) and any(field.resolution != resolution):
msg = "Resolution mismatch for dfield {}."
msg = msg.format(field.name)
raise RuntimeError(msg)
if (compute_resolution is not None) and any(
field.compute_resolution != compute_resolution
):
msg = "Local resolution mismatch for dfield {}."
msg = msg.format(field.name)
raise RuntimeError(msg)
if (dtype is not None) and (field.dtype != dtype):
msg = "dtype mismatch for dfield {}."
msg = msg.format(field.name)
raise RuntimeError(msg)
if (nb_components is not None) and (field.nb_components != nb_components):
msg = "nb_components mismatch for dfield {}."
msg = msg.format(field.name)
raise RuntimeError(msg)
if (ghosts is not None) and (field.ghosts != ghosts):
msg = "ghosts mismatch for dfield {}."
msg = msg.format(field.name)
raise RuntimeError(msg)
if (min_ghosts is not None) and npw.any(field.ghosts < min_ghosts):
msg = "Min ghosts mismatch for dfield {}, expected {} got {}."
msg = msg.format(field.name, min_ghosts, field.ghosts)
raise RuntimeError(msg)
if (max_ghosts is not None) and npw.any(field.ghosts > max_ghosts):
msg = "max ghosts mismatch for dfield {}, expected {} got {}."
msg = msg.format(field.name, max_ghosts, field.ghosts)
raise RuntimeError(msg)
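# Usage sketch (illustrative, the field and the expected values are
# assumptions): every check is optional and only performed when the
# corresponding keyword is not None, e.g.
#
#     AutotunableKernel.check_cartesian_field(
#         field, dtype=npw.float64, min_ghosts=(2, 2, 2)
#     )
#
# A failed check raises a RuntimeError naming the offending field.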
@classmethod
def check_cartesian_fields(cls, *fields, **kwds):
"""
Check that the given fields are compatible (defined on the same domain).
By default, also compare dtypes, number of components and size.
Checks can be enabled or disabled by passing
check_[res,cres,size,components,dtype] as boolean keyword arguments.
"""
check_instance(
fields, tuple, values=CartesianDiscreteScalarFieldView, minsize=1
)
check_resolution = kwds.pop("check_res", False)
check_compute_resolution = kwds.pop("check_cres", False)
check_size = kwds.pop("check_size", True)
check_nb_components = kwds.pop("check_components", True)
check_dtype = kwds.pop("check_dtype", True)
assert not kwds, f"Unused keyword arguments {kwds.keys()}."
domain = fields[0].domain
resolution = fields[0].resolution
compute_resolution = fields[0].compute_resolution
dtype = fields[0].dtype
size = fields[0].npoints
nb_components = fields[0].nb_components
for field in fields:
if field.domain.domain is not domain.domain:
msg = "Domain mismatch between dfield {} and dfield {}."
msg = msg.format(fields[0].name, field.name)
raise RuntimeError(msg)
if check_size and (field.npoints != size):
msg = "Size mismatch between dfield {} and dfield {}."
msg = msg.format(fields[0].name, field.name)
raise RuntimeError(msg)
if check_resolution and any(field.resolution != resolution):
msg = "Resolution mismatch between dfield {} and dfield {}."
msg = msg.format(fields[0].name, field.name)
raise RuntimeError(msg)
if check_compute_resolution and any(
field.compute_resolution != compute_resolution
):
msg = "Local resolution mismatch between dfield {} and dfield {}."
msg = msg.format(fields[0].name, field.name)
raise RuntimeError(msg)
if check_dtype and (field.dtype != dtype):
msg = "dtype mismatch between dfield {} and dfield {}."
msg = msg.format(fields[0].name, field.name)
raise RuntimeError(msg)
if check_nb_components and (field.nb_components != nb_components):
msg = "nb_components mismatch between dfield {} and dfield {}."
msg = msg.format(fields[0].name, field.name)
raise RuntimeError(msg)
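# Usage sketch (illustrative, velocity_d and vorticity_d are hypothetical
# discrete fields): compare two fields defined on the same domain while
# skipping the dtype check:
#
#     AutotunableKernel.check_cartesian_fields(
#         velocity_d, vorticity_d, check_dtype=False
#     )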
def mesh_info(self, name, mesh):
"""Create a MeshInfoStruct from a CartesianMesh."""
return MeshInfoStruct.create_from_mesh(
name=name, mesh=mesh, typegen=self.typegen
)[1]
def output_mesh_info(self, field):
"""Create a MeshInfoStruct for an output DisreteCartesianField."""
name = f"{field.name}_out_field_mesh_info"
return self.mesh_info(name=name, mesh=field.mesh.mesh)
class AutotunerParameterConfiguration:
"""Helper class for kernel autotuning to handle extra parameters."""
def __init__(self, **kwds):
super().__init__(**kwds)
self._param_names = ()
self._parameters = {}
def _get_parameter_names(self):
return self._param_names
def _get_parameters(self):
return self._parameters
param_names = property(_get_parameter_names)
parameters = property(_get_parameters)
def iter_parameters(self):
param_names = self._param_names
param_values = tuple(self._parameters[pname] for pname in param_names)
param_iterator = it.product(*param_values)
for params in param_iterator:
extra_parameters = dict(zip(param_names, params))
yield extra_parameters
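# Usage sketch (illustrative; the parameter names are made up and the private
# fields are set directly here, bypassing whatever registration helpers the
# full class provides): iter_parameters yields every combination of the
# candidate values as a dict, in Cartesian-product order:
#
#     params = AutotunerParameterConfiguration()
#     params._param_names = ("vectorization", "unroll")
#     params._parameters = {"vectorization": (1, 2, 4), "unroll": (0, 1)}
#     for extra_parameters in params.iter_parameters():
#         ...  # {'vectorization': 1, 'unroll': 0}, {'vectorization': 1, 'unroll': 1}, ...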
class AutotunerWorkBoundsConfiguration:
"""Helper class for kernel autotuning to handle work bounds."""
def __init__(
self,
work_dim,
work_size,
min_work_load,
max_work_load,
max_device_work_dim,
max_device_work_group_size,
max_device_work_item_sizes,
preferred_work_group_size_multiple,
**kwds,
):
super().__init__(**kwds)
assert (
work_dim <= max_device_work_dim
), f"work_dim {work_dim} > {max_device_work_dim}"
work_dim = int(work_dim)
assert work_dim > 0
assert (
preferred_work_group_size_multiple > 0
), preferred_work_group_size_multiple
work_size = npw.asarray(work_size, dtype=npw.int32)
min_work_load = npw.asarray(min_work_load, dtype=npw.int32)
max_work_load = npw.asarray(max_work_load, dtype=npw.int32)
check_instance(work_size, npw.ndarray, dtype=npw.int32, size=work_dim)
check_instance(min_work_load, npw.ndarray, dtype=npw.int32, size=work_dim)
check_instance(max_work_load, npw.ndarray, dtype=npw.int32, size=work_dim)
assert (work_size > 0).all()
assert (min_work_load > 0).all()
assert (max_work_load >= min_work_load).all()
self._work_dim = work_dim
self._work_size = work_size
self._min_work_load = min_work_load
self._max_work_load = max_work_load
self._max_device_work_dim = int(max_device_work_dim)
self._max_device_work_group_size = int(max_device_work_group_size)
self._max_device_work_item_sizes = npw.asarray(
max_device_work_item_sizes[:work_dim], dtype=npw.int32
)
self._preferred_work_group_size_multiple = preferred_work_group_size_multiple
self._generate_work_loads()
def _get_work_dim(self):
return self._work_dim
def _get_work_size(self):
return self._work_size
def _get_min_work_load(self):
return self._min_work_load
def _get_max_work_load(self):
return self._max_work_load
def _get_max_device_work_dim(self):
return self._max_device_work_dim
def _get_max_device_work_group_size(self):
return self._max_device_work_group_size
def _get_max_device_work_item_sizes(self):
return self._max_device_work_item_sizes
def _get_preferred_work_group_size_multiple(self):
return self._preferred_work_group_size_multiple
work_dim = property(_get_work_dim)
work_size = property(_get_work_size)
min_work_load = property(_get_min_work_load)
max_work_load = property(_get_max_work_load)
max_device_work_dim = property(_get_max_device_work_dim)
max_device_work_group_size = property(_get_max_device_work_group_size)
max_device_work_item_sizes = property(_get_max_device_work_item_sizes)
preferred_work_group_size_multiple = property(
_get_preferred_work_group_size_multiple
)
def _generate_work_loads(self):
work_size = self.work_size
min_work_load, max_work_load = self.min_work_load, self.max_work_load
min_work_load = npw.minimum(min_work_load, work_size)
max_work_load = npw.minimum(max_work_load, work_size)
def _compute_pows(minw, maxw):
res = []
wl = minw
while wl < maxw:
res.append(wl)
wl = next_pow2(wl)
res.append(maxw)
res = tuple(res)
return res
work_loads = tuple(
_compute_pows(min_w, max_w)
for (min_w, max_w) in zip(min_work_load.tolist(), max_work_load.tolist())
)
work_loads = it.product(*work_loads)
self._work_loads = tuple(work_loads)
def iter_work_loads(self):
for wl in self._work_loads:
yield npw.asarray(wl, dtype=npw.int32)
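# Example of the enumeration above (values are assumptions, and next_pow2 is
# assumed to return the next strictly greater power of two): with
# min_work_load=(1, 1), max_work_load=(4, 2) and a large enough work_size,
# _compute_pows yields (1, 2, 4) and (1, 2) per dimension, so iter_work_loads
# produces (1, 1), (1, 2), (2, 1), (2, 2), (4, 1), (4, 2).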
class AutotunerWorkConfiguration:
__debug_filters = False
def __init__(
self, work_bounds, work_load, min_wg_size, max_wg_size, ordered_workload=True
):
check_instance(work_bounds, AutotunerWorkBoundsConfiguration)
check_instance(
work_load, npw.ndarray, dtype=npw.int32, size=work_bounds.work_dim
)
check_instance(
min_wg_size, npw.ndarray, dtype=npw.int32, size=work_bounds.work_dim
)
check_instance(
max_wg_size, npw.ndarray, dtype=npw.int32, size=work_bounds.work_dim
)
assert (min_wg_size >= 1).all(), f"min_wg_size = {min_wg_size}"
assert (min_wg_size <= max_wg_size).all(), f"{min_wg_size} > {max_wg_size}"
self._work_bounds = work_bounds
self._work_load = work_load
self._global_work_size = (work_bounds.work_size + work_load - 1) // work_load
self._filters = {}
self._filter_names = ()
self._min_wg_size = min_wg_size
self._max_wg_size = max_wg_size
self._local_work_size_generator = self._default_work_size_generator
self._generate_unfiltered_candidates()
self._load_default_filters(work_bounds, ordered_workload)
def _get_work_bounds(self):
return self._work_bounds
def _get_work_load(self):
return self._work_load
def _get_global_work_size(self):
return self._global_work_size
def _get_filters(self):
return self._filters
def _get_filter_names(self):
return self._filter_names
def _get_work_dim(self):
return self._work_bounds.work_dim
work_bounds = property(_get_work_bounds)
work_load = property(_get_work_load)
work_dim = property(_get_work_dim)
global_work_size = property(_get_global_work_size)
filters = property(_get_filters)
filter_names = property(_get_filter_names)
def _generate_unfiltered_candidates(self):
candidates = self._local_work_size_generator()
check_instance(candidates, tuple, values=npw.ndarray, minsize=1)
self._unfiltered_candidates = candidates
def _default_work_size_generator(self):
"""Default local_work_size generator."""
min_wi_size = self._min_wg_size
max_wi_size = self._max_wg_size
def _compute_pows(min_wi, max_wi):
res = []
wi = min_wi
while wi < max_wi:
res.append(wi)
wi = next_pow2(wi)
res.append(max_wi)
res = tuple(res)
return res
work_items = tuple(
_compute_pows(min_wi, max_wi)[::-1]
for (min_wi, max_wi) in zip(min_wi_size.tolist(), max_wi_size.tolist())
)
wi_candidates = it.product(*work_items)
return tuple(npw.asarray(wi, dtype=npw.int32) for wi in wi_candidates)
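# Example (values are assumptions): with min_wg_size=(1, 1) and
# max_wg_size=(8, 4), each dimension enumerates its powers of two in
# decreasing order, (8, 4, 2, 1) and (4, 2, 1), and the candidates are their
# Cartesian product, starting from the largest local size (8, 4).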
def set_local_work_size_generator(self, fn):
"""
Set a custom local_work_size generator that
will generate a set of local_work_sizes to be
filtered.
"""
assert callable(fn)
self._local_work_size_generator = fn
self._generate_unfiltered_candidates()
def iter_local_work_size(self):
"""Iterates over filtered work sizes."""
candidates = self._unfiltered_candidates
if self.__debug_filters:
msg = " *Initial workitems candidates:\n {}\n".format(
tuple(tuple(x) for x in candidates)
)
print(msg)
for fname in self.filter_names:
fn = self._filters[fname]
candidates = tuple(filter(fn, candidates))
if self.__debug_filters:
candidates, _ = it.tee(candidates)
msg = " *Filter {}:\n {}\n".format(fname, tuple(tuple(x) for x in _))
print(msg)
return candidates
def push_filter(self, filter_name, filter_fn, **filter_kwds):
"""Push a named local_work_size filter with custom keywords."""
check_instance(filter_name, str)
assert callable(filter_fn)
if filter_name in self._filter_names:
msg = "Filter {} has already been registered."
msg = msg.format(filter_name)
raise RuntimeError(msg)
filter_fn = functools.partial(filter_fn, **filter_kwds)
self._filter_names += (filter_name,)
self._filters[filter_name] = filter_fn
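# Sketch of a custom filter (hypothetical predicate and instance name, not
# part of this class): any callable taking a candidate local_work_size plus
# the keywords bound through push_filter and returning a bool can be used:
#
#     def multiple_of_filter(local_work_size, multiple):
#         return (local_work_size[0] % multiple) == 0
#
#     work.push_filter("multiple_of_32", multiple_of_filter, multiple=32)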
def _load_default_filters(self, work_bounds, ordered_workload):
"""Load default local_work_size filters (mostly device limitations.)"""
self.push_filter(
f"max_device_work_item_sizes (default filter, max_work_item_sizes={work_bounds.max_device_work_item_sizes})",
self.max_wi_sizes_filter,
max_work_item_sizes=work_bounds.max_device_work_item_sizes,
)
self.push_filter(
f"max_device_work_group_size (default filter, max_device_work_group_size={work_bounds.max_device_work_group_size})",
self.max_wg_size_filter,
max_work_group_size=work_bounds.max_device_work_group_size,
)
if ordered_workload:
self.push_filter("ordered_workload (default)", self.ordered_workload_filter)
@staticmethod
def max_wi_sizes_filter(local_work_size, max_work_item_sizes):
"""Filter out work items by size given a maximum size."""
return (local_work_size <= max_work_item_sizes).all()
@staticmethod
def min_wi_sizes_filter(local_work_size, min_work_item_sizes):
"""Filter out work items by size given a minimum size."""
return (local_work_size >= min_work_item_sizes).all()
@staticmethod
def max_wg_size_filter(local_work_size, max_work_group_size):
"""Filter out work items by workgroup size given a maximum workgroup size."""
return npw.prod(local_work_size, dtype=npw.int64) <= max_work_group_size
@staticmethod
def ordered_workload_filter(local_work_size):
"""Filter out work items by decreasing dimensional sizes."""
oldval = local_work_size[0]
for val in local_work_size[1:]:
if val > oldval:
return False
oldval = val
return True
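# For example (illustrative values): (64, 8, 1) passes this filter while
# (8, 64, 1) is rejected, keeping only candidates whose sizes do not
# increase from the first to the last work dimension.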
@abstractmethod
def make_parameter(self, param):
pass
@abstractmethod
def make_array_offset(self, dim):
pass
@abstractmethod
def make_array_strides(self, dim):
pass
@abstractmethod
def make_array_args(self, **arrays):
pass
@abstractmethod
def make_dt(self, dtype):
pass